import os
import shutil
from tqdm import tqdm

# Dataset locations: source image folders are read from directly under
# `root`; the split is written to the `train` and `test` folders below.
root = '/mnt/data1/dzy/datas/CDC_max_face_size'
train_dir = '/mnt/data1/dzy/datas/CDC_max_face_size/train'
test_dir = '/mnt/data1/dzy/datas/CDC_max_face_size/test'
# Fraction of each source folder's images that is assigned to the train set.
train_rate = 0.95

os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Source folders are the entries of `root` whose name starts with '头大'.
# Non-directories with a matching name are skipped, because os.listdir()
# would raise on them below.
# os.listdir() returns entries in arbitrary order, so the folder list is
# sorted to make the processing order reproducible.
img_dirs = sorted(
    os.path.join(root, x)
    for x in os.listdir(root)
    if x.startswith('头大') and os.path.isdir(os.path.join(root, x))
)

# Map each source folder to the sorted names of the regular files it holds.
# Sorting matters: the train/test split is taken positionally from this list,
# so an arbitrary listing order would let a re-run assign the same image to a
# different subset. Sub-directories are dropped because shutil.copy() cannot
# copy them.
# NOTE(review): a sorted split is deterministic but not random; if file names
# correlate with content, consider a shuffle with a fixed seed instead.
all_imgs = {
    path: sorted(
        name
        for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )
    for path in img_dirs
}

# Report how many images were found in each source folder.
for k, v in all_imgs.items():
    print(k, len(v))

def _copy_images(src_dir, names, dst_dir, taken, progress=False):
    """Copy the files `names` from `src_dir` into `dst_dir`.

    All source folders are copied into the same flat destination folder, so
    two folders containing the same file name would silently overwrite each
    other. `taken` is the set of destination names already written to
    `dst_dir` during this run; a name that is already taken is prefixed with
    the source folder's base name until it is unique.

    Args:
        src_dir: folder the files are read from.
        names: file names (relative to `src_dir`) to copy.
        dst_dir: folder the files are copied into.
        taken: set of names already used in `dst_dir`; updated in place.
        progress: when true, wrap the iteration in a tqdm progress bar.

    Returns:
        The number of files that had to be renamed to avoid a collision.
    """
    prefix = os.path.basename(src_dir)
    renamed = 0
    for name in (tqdm(names) if progress else names):
        dst_name = name
        if dst_name in taken:
            renamed += 1
            # Each pass makes the name longer, so the loop always terminates.
            while dst_name in taken:
                dst_name = prefix + '_' + dst_name
        taken.add(dst_name)
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, dst_name))
    return renamed


# Destination names written so far, tracked separately for each output folder.
train_taken = set()
test_taken = set()

# For every source folder, the first `train_rate` share of its listed images
# goes to the train folder and the remainder goes to the test folder.
for k, v in all_imgs.items():
    # int() truncates, so the test set gets any left-over image.
    train_num = int(len(v) * train_rate)
    train_imgs = v[:train_num]
    test_imgs = v[train_num:]
    renamed = _copy_images(k, train_imgs, train_dir, train_taken, progress=True)
    renamed += _copy_images(k, test_imgs, test_dir, test_taken)
    print(k, len(train_imgs), len(test_imgs))
    if renamed:
        print(k, 'renamed', renamed, 'file(s) whose name was already used by another folder')